
import numpy as np
from scipy.optimize import bisect
import matplotlib.pyplot as plt
from astropy import units as u
from astropy.time import Time
from netorchestr.envir.base import OModule
from netorchestr.envir.node.ue import UeBase
from netorchestr.envir.node.ground import GroundServerBase
from netorchestr.envir.node.uav import UavServerBase
from netorchestr.envir.node.container import VnfBase
from netorchestr.eventlog import OLogItem

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from netorchestr.envir.base import ONet

class ControllerGlobal4Ue(OModule):
    """Global controller that adapts VNF processing rates for two SFCs.

    SFC 1 uses the first VNF on servers ``Vim1``, ``Vim2`` and ``Vim3``.
    Its delay is observed at ``Ue2`` and its packets are addressed to
    ``Ue2App``. SFC 2 uses the first VNF on ``Vim1`` and ``Vim2``. Its delay
    is observed at ``Ue4`` and its packets are addressed to ``Ue4App``.

    Every ``s*_aim_delay`` time units a proportional controller nudges each
    SFC's per-VNF rate share towards its delay target. A VNF shared by both
    SFCs (Vim1, Vim2) is given the sum of the two shares.
    """

    def __init__(self, name:str, net:"ONet"):
        super().__init__(name)

        self.net = net
        self.logger = net.logger
        self.scheduler = net.scheduler

        # Temporary test parameters.

        # region Service function chain 1
        self.s1_aim_delay = 6          # delay target; also the control period of SFC 1
        self.s1_aim_overrate = 0.1     # tolerated fraction of delays above the target
        self.s1_arrive_rate = 1        # arrival-curve rate used by the delay model
        self.s1_actural_delay = []     # last observed delay, one sample per control step
        self.s1_predict_delay = []     # model delay bound, one sample per control step
        self.s1_kff = 0.004            # feed-forward gain (feed-forward term currently disabled)

        self.s1_kp = 0.004             # proportional gain (the only active term)
        self.s1_ki = 0.0001            # integral gain (integral term currently disabled)
        self.s1_kd = 0.0001            # derivative gain (derivative term currently disabled)
        self.s1_i_val = 0
        self.s1_err_prev = 0

        # Initial rate budget; every VNF of the chain starts at half of it.
        # It could instead be derived with self.netcalculate_inverse(...).
        self.s1_process_rate_init = 2
        self.s1v1_process_rate = [self.s1_process_rate_init*0.5]
        self.s1v2_process_rate = [self.s1_process_rate_init*0.5]
        self.s1v3_process_rate = [self.s1_process_rate_init*0.5]
        self.s1v1_queue_len = [0]
        self.s1v2_queue_len = [0]
        self.s1v3_queue_len = [0]
        self.s1_revenue_list = [0]
        self.s1_cost_list = [0]
        # endregion

        # region Service function chain 2
        self.s2_aim_delay = 4          # delay target; also the control period of SFC 2
        self.s2_aim_overrate = 0.1     # tolerated fraction of delays above the target
        self.s2_arrive_rate = 1        # arrival-curve rate used by the delay model
        self.s2_actural_delay = []     # last observed delay, one sample per control step
        self.s2_predict_delay = []     # model delay bound, one sample per control step
        self.s2_kff = 0.004            # feed-forward gain (feed-forward term currently disabled)

        self.s2_kp = 0.004             # proportional gain (the only active term)
        self.s2_ki = 0.0001            # integral gain (integral term currently disabled)
        self.s2_kd = 0.0001            # derivative gain (derivative term currently disabled)
        self.s2_i_val = 0
        self.s2_err_prev = 0

        # Initial rate budget; every VNF of the chain starts at half of it.
        # It could instead be derived with self.netcalculate_inverse(...).
        self.s2_process_init = 2
        self.s2v1_process_rate = [self.s2_process_init*0.5]
        self.s2v2_process_rate = [self.s2_process_init*0.5]
        self.s2v1_queue_len = [0]
        self.s2v2_queue_len = [0]
        self.s2_revenue_list = [0]
        self.s2_cost_list = [0]
        # endregion

    # region Lookup helpers

    def _find_vnf(self, vim_name: str) -> "VnfBase":
        """Return the first VNF hosted on the UAV/ground server module named ``vim_name``.

        Raises:
            LookupError: if no such server module exists in ``self.net.modules``.
        """
        for module in self.net.modules:
            if isinstance(module, (UavServerBase, GroundServerBase)) and module.name == vim_name:
                return module.vnfList[0]
        raise LookupError(f"no UAV/ground server module named {vim_name!r}")

    def _find_ue(self, ue_name: str) -> "UeBase":
        """Return the UE module named ``ue_name``.

        Raises:
            LookupError: if no such UE module exists in ``self.net.modules``.
        """
        for module in self.net.modules:
            if isinstance(module, UeBase) and module.name == ue_name:
                return module
        raise LookupError(f"no UE module named {ue_name!r}")

    @staticmethod
    def _queue_len_for(vnf, receiver: str) -> int:
        """Count requests queued on ``vnf`` whose message is addressed to ``receiver``."""
        app = vnf.appLayer
        return sum(1 for req in app.req_processor.queue if app.req2msg[req].receiver == receiver)

    # endregion

    def initialize(self):
        """Push the initial rate shares to the VNFs and register both control loops with the scheduler."""
        self.logger.debug(f"{self.scheduler.now}: ControllerGlobal initialize")

        vnf1 = self._find_vnf("Vim1")
        vnf2 = self._find_vnf("Vim2")
        vnf3 = self._find_vnf("Vim3")

        # Vim1/Vim2 serve both SFCs, so they get the sum of both shares.
        vnf1.appLayer.process_rate = self.s1v1_process_rate[-1] + self.s2v1_process_rate[-1]
        vnf2.appLayer.process_rate = self.s1v2_process_rate[-1] + self.s2v2_process_rate[-1]
        vnf3.appLayer.process_rate = self.s1v3_process_rate[-1]

        self.scheduler.process(self.controller_process_sfc_1())
        self.scheduler.process(self.controller_process_sfc_2())

    def calculate_revenue(self, performance_record:list[float], aim_delay:float, aim_overrate:float,
                          arrive_rate:float, watch_window:int = 100):
        """Return an SFC's revenue for the current control step.

        ``R = arrive_rate * R_D * R_P``, where:

        - ``R_D = 1 - last_delay / aim_delay`` rewards slack w.r.t. the target.
        - ``R_P = exp(-P / aim_overrate)`` penalises ``P``, the fraction of
          the last ``watch_window`` records (or of all records, if fewer)
          that exceed ``aim_delay``.

        Args:
            performance_record: observed delays, oldest first; must be non-empty.
            aim_delay: delay target.
            aim_overrate: tolerated fraction of delays above the target.
            arrive_rate: scale factor applied to the revenue.
            watch_window: number of most recent records used to compute ``P``.

        Raises:
            IndexError: if ``performance_record`` is empty.
            ValueError: if ``watch_window`` is smaller than 1.
        """
        if watch_window < 1:
            raise ValueError(f"watch_window must be >= 1, got {watch_window}")
        R_D = 1 - performance_record[-1]/aim_delay
        # A negative-start slice yields the last min(watch_window, len) records.
        window = performance_record[-watch_window:]
        P = sum(delay > aim_delay for delay in window)/len(window)
        R_P = np.exp(-P/aim_overrate)
        return arrive_rate * R_D * R_P

    def netcalculate(self, service_curve:dict[str, float], arrive_curve:dict[str, float], aim_overrate:float, theta:float, nodenum:int):
        """Compute the model's delay bound and a second bound for a chain of ``nodenum`` nodes.

        Both curves are dicts with keys ``'rho'`` (rate) and ``'sigma'`` (burst).

        Returns:
            ``(tau, block)`` where

            - ``tau = sigma_s/rho_s + (nodenum*ln(gamma) - ln(aim_overrate)) / (theta*rho_s)``
            - ``block = sigma_s + (ln(gamma) - ln(aim_overrate)) / theta``
            - ``gamma = 1 + 1/(1 - exp(-theta*(rho_s - rho_a)))``

            ``block`` is presumably a backlog bound; confirm. Returns
            ``(inf, inf)`` when the service rate is below the arrival rate.
        """
        if service_curve['rho'] < arrive_curve['rho']:
            return np.inf, np.inf
        elif service_curve['rho'] == arrive_curve['rho']:
            # NOTE(review): the general gamma formula diverges at equal rates
            # (1/(1-exp(0))); gamma=1 is used instead — confirm this is intended.
            gamma = 1
        else:
            gamma = 1 + 1 / (1 - np.exp(-theta * (service_curve['rho']-arrive_curve['rho'])))

        tau_left = service_curve['sigma'] / service_curve['rho']
        tau_right = (nodenum * np.log(gamma) - np.log(aim_overrate)) / (theta * service_curve['rho'])
        tau = tau_left + tau_right
        block = service_curve['sigma'] + (np.log(gamma) - np.log(aim_overrate))/theta
        return tau, block

    def netcalculate_inverse(self,arrive_curve:dict[str, float], aim_delay:float, aim_overrate:float, theta:float, nodenum:int):
        """Find the service rate whose delay bound equals ``aim_delay``.

        This inverts the ``tau`` expression of :meth:`netcalculate` with zero
        service burst (sigma = 0). It uses bisection on
        ``(arrive_rho + 0.001, 10 * arrive_rho)``. ``scipy.optimize.bisect``
        raises ``ValueError`` if the target is not bracketed by that interval.
        """
        def equation(x, y_target):
            numerator = (nodenum * np.log(1 + 1 / (1 - np.exp(-theta * (x - arrive_curve['rho']))))) - np.log(aim_overrate)
            denominator = theta * x
            return numerator/denominator - y_target

        rho_left = arrive_curve['rho'] + 0.001
        rho_right = arrive_curve['rho'] * 10
        rho = bisect(lambda x: equation(x, aim_delay), rho_left, rho_right)
        return rho

    def controller_process_sfc_1(self):
        """Run the periodic control loop for SFC 1 as a scheduler process (never returns).

        Every ``s1_aim_delay`` time units, once ``Ue2`` has recorded a delay,
        it performs these steps:

        1. Records the predicted and observed delay.
        2. Samples SFC 1's queue lengths on Vim1-3.
        3. Applies a proportional correction to SFC 1's share on each VNF.
        4. Pushes the combined rates to the VNFs.
        5. Logs revenue and cost.
        """
        while True:
            yield self.scheduler.timeout(self.s1_aim_delay)

            ue2:UeBase = self._find_ue('Ue2')

            if len(ue2.appLayer.performance_record) == 0:
                continue

            # Bottleneck rate of the chain, i.e. the rate used for the delay bound.
            service_rho = min(self.s1v1_process_rate[-1], self.s1v2_process_rate[-1], self.s1v3_process_rate[-1])
            tau, block = self.netcalculate(service_curve = {'rho':service_rho,'sigma':0},
                                           arrive_curve = {'rho':self.s1_arrive_rate,'sigma':0},
                                           aim_overrate = self.s1_aim_overrate, theta = 0.1, nodenum = 3)
            predict_delay = tau  # link delays are not yet added to the prediction
            self.s1_predict_delay.append(predict_delay)
            self.s1_actural_delay.append(ue2.appLayer.performance_record[-1])

            vnf1 = self._find_vnf("Vim1")
            vnf2 = self._find_vnf("Vim2")
            vnf3 = self._find_vnf("Vim3")

            self.s1v1_queue_len.append(self._queue_len_for(vnf1, "Ue2App"))
            self.s1v2_queue_len.append(self._queue_len_for(vnf2, "Ue2App"))
            self.s1v3_queue_len.append(self._queue_len_for(vnf3, "Ue2App"))

            # Smooth the observation over the 20 most recent delay samples.
            value_actural = np.average(np.array(ue2.appLayer.performance_record[-20:]))

            # Only the proportional term is active. The feed-forward term
            # (s1_kff * (predict_delay - s1_aim_delay)), the integral term and
            # the derivative term are disabled and fixed at 0.
            value_Rff = 0
            value_err = value_actural - self.s1_aim_delay
            value_P = self.s1_kp * value_err
            value_I = 0
            value_D = 0
            value_PIDff = value_P + value_I + value_D + value_Rff

            self.s1_i_val = value_I
            self.s1_err_prev = value_err

            # Keep every share strictly above the arrival rate.
            rate_floor = self.s1_arrive_rate + 0.001
            self.s1v1_process_rate.append(max(rate_floor, self.s1v1_process_rate[-1] + value_PIDff))
            self.s1v2_process_rate.append(max(rate_floor, self.s1v2_process_rate[-1] + value_PIDff))
            self.s1v3_process_rate.append(max(rate_floor, self.s1v3_process_rate[-1] + value_PIDff))

            # Apply to the VNFs; Vim1/Vim2 also carry SFC 2's share.
            vnf1.appLayer.process_rate = self.s1v1_process_rate[-1] + self.s2v1_process_rate[-1]
            vnf2.appLayer.process_rate = self.s1v2_process_rate[-1] + self.s2v2_process_rate[-1]
            vnf3.appLayer.process_rate = self.s1v3_process_rate[-1]

            self.s1_revenue_list.append(self.calculate_revenue(ue2.appLayer.performance_record, self.s1_aim_delay, self.s1_aim_overrate, self.s1_arrive_rate))
            self.s1_cost_list.append(self.s1v1_process_rate[-1] + self.s1v2_process_rate[-1] + self.s1v3_process_rate[-1])

            print(f"\n Current time: {self.scheduler.now} \t | s1_perform: {ue2.appLayer.performance_record[-1]} \n \
                | vnf1_rate: {vnf1.appLayer.process_rate} \t | vnf2_rate: {vnf2.appLayer.process_rate} \t | vnf3_rate: {vnf3.appLayer.process_rate} \n \
                | s1v1_rate: {self.s1v1_process_rate[-1]} \t | s1v2_rate: {self.s1v2_process_rate[-1]} \t | s1v3_rate: {self.s1v3_process_rate[-1]} \n \
                | s1v1_queue: {self.s1v1_queue_len[-1]} \t | s1v2_queue: {self.s1v2_queue_len[-1]} \t | s1v3_queue: {self.s1v3_queue_len[-1]}")

            # Report the rho that tau was actually computed from (pre-update).
            print(f"Netcalculate: tau: {tau}  block: {block}  with rho: {service_rho}")

            print(f"value_actural: {value_actural} | value_err: {value_err} | value_P: {value_P} | value_I: {value_I} | value_D: {value_D} | value_PIDff: {value_PIDff}")
            print("-"*10+"\n")

    def controller_process_sfc_2(self):
        """Run the periodic control loop for SFC 2 as a scheduler process (never returns).

        Every ``s2_aim_delay`` time units, once ``Ue4`` has recorded a delay,
        it performs these steps:

        1. Records the predicted and observed delay.
        2. Samples SFC 2's queue lengths on Vim1-2.
        3. Applies a proportional correction to SFC 2's share on each VNF.
        4. Pushes the combined rates to the VNFs.
        5. Logs revenue and cost.
        """
        while True:
            yield self.scheduler.timeout(self.s2_aim_delay)

            ue4:UeBase = self._find_ue('Ue4')

            if len(ue4.appLayer.performance_record) == 0:
                continue

            # Bottleneck rate of the chain, i.e. the rate used for the delay bound.
            service_rho = min(self.s2v1_process_rate[-1], self.s2v2_process_rate[-1])
            tau, block = self.netcalculate(service_curve = {'rho':service_rho,'sigma':0},
                                           arrive_curve = {'rho':self.s2_arrive_rate,'sigma':0},
                                           aim_overrate = self.s2_aim_overrate, theta = 0.1, nodenum = 2)
            predict_delay = tau  # link delays are not yet added to the prediction
            self.s2_predict_delay.append(predict_delay)
            self.s2_actural_delay.append(ue4.appLayer.performance_record[-1])

            vnf1 = self._find_vnf("Vim1")
            vnf2 = self._find_vnf("Vim2")

            self.s2v1_queue_len.append(self._queue_len_for(vnf1, "Ue4App"))
            self.s2v2_queue_len.append(self._queue_len_for(vnf2, "Ue4App"))

            # Smooth the observation over the 20 most recent delay samples.
            value_actural = np.average(np.array(ue4.appLayer.performance_record[-20:]))

            # Only the proportional term is active. The feed-forward term
            # (s2_kff * (predict_delay - s2_aim_delay)), the integral term and
            # the derivative term are disabled and fixed at 0.
            value_Rff = 0
            value_err = value_actural - self.s2_aim_delay
            value_P = self.s2_kp * value_err
            value_I = 0
            value_D = 0
            value_PIDff = value_P + value_I + value_D + value_Rff

            self.s2_i_val = value_I
            self.s2_err_prev = value_err

            # Keep every share strictly above the arrival rate.
            rate_floor = self.s2_arrive_rate + 0.001
            self.s2v1_process_rate.append(max(rate_floor, self.s2v1_process_rate[-1] + value_PIDff))
            self.s2v2_process_rate.append(max(rate_floor, self.s2v2_process_rate[-1] + value_PIDff))

            # Apply to the VNFs; both are shared with SFC 1.
            vnf1.appLayer.process_rate = self.s2v1_process_rate[-1] + self.s1v1_process_rate[-1]
            vnf2.appLayer.process_rate = self.s2v2_process_rate[-1] + self.s1v2_process_rate[-1]

            self.s2_revenue_list.append(self.calculate_revenue(ue4.appLayer.performance_record, self.s2_aim_delay, self.s2_aim_overrate, self.s2_arrive_rate))
            self.s2_cost_list.append(self.s2v1_process_rate[-1] + self.s2v2_process_rate[-1])

            print(f"\n Current time: {self.scheduler.now} \t | s2_perform: {ue4.appLayer.performance_record[-1]} \n \
                | vnf1_rate: {vnf1.appLayer.process_rate} \t | vnf2_rate: {vnf2.appLayer.process_rate} \n \
                | s2v1_rate: {self.s2v1_process_rate[-1]} \t | s2v2_rate: {self.s2v2_process_rate[-1]} \n \
                | s2v1_queue: {self.s2v1_queue_len[-1]} \t | s2v2_queue: {self.s2v2_queue_len[-1]}")

            # Report the rho that tau was actually computed from (pre-update).
            print(f"Netcalculate: tau: {tau}  block: {block}  with rho: {service_rho}")

            print(f"value_actural: {value_actural} | value_err: {value_err} | value_P: {value_P} | value_I: {value_I} | value_D: {value_D} | value_PIDff: {value_PIDff}")
            print("-"*10+"\n")

    @staticmethod
    def _save_series_figure(series_specs, ylabel: str, path: str):
        """Plot each ``(values, label, marker, linestyle)`` against its index, save to ``path`` and close the figure."""
        fig = plt.figure(figsize=(10, 5))
        for values, label, marker, linestyle in series_specs:
            plt.plot(list(range(len(values))), values, label=label, marker=marker, linestyle=linestyle)
        plt.xlabel("Time")
        plt.ylabel(ylabel)
        plt.legend()
        plt.savefig(path)
        plt.close(fig)  # pyplot keeps figures alive until explicitly closed

    def draw_results(self):
        """Save result plots and print summary statistics for both SFCs.

        Writes ``<net.name>_delay.png``, ``<net.name>_queue.png`` and
        ``<net.name>_process_rate.png`` (relative paths). It then prints, for
        each SFC that has data, the share of logged packet delays above its
        target and its revenue-to-cost ratio.
        """
        # region Plots
        fig = plt.figure(figsize=(10, 5))
        s1_steps = list(range(len(self.s1_actural_delay)))
        s2_steps = list(range(len(self.s2_actural_delay)))
        plt.scatter(x=s1_steps, y=self.s1_actural_delay, color="black", marker=".", s=10, label="SFC1")
        plt.scatter(x=s2_steps, y=self.s2_actural_delay, color="red", marker=".", s=10, label="SFC2")
        plt.plot(s1_steps, [self.s1_aim_delay]*len(s1_steps), linestyle="--", color="black", label="SFC1_aim")
        plt.plot(s2_steps, [self.s2_aim_delay]*len(s2_steps), linestyle="--", color="red", label="SFC2_aim")
        plt.xlabel("Packet")
        plt.ylabel("Delay")
        plt.legend()
        plt.savefig(f"{self.net.name}_delay.png")
        plt.close(fig)  # pyplot keeps figures alive until explicitly closed

        self._save_series_figure(
            (
                (self.s1v1_queue_len, "s1v1_queue", ".", "-"),
                (self.s1v2_queue_len, "s1v2_queue", "o", "-"),
                (self.s1v3_queue_len, "s1v3_queue", "*", "-"),
                (self.s2v1_queue_len, "s2v1_queue", ".", "--"),
                (self.s2v2_queue_len, "s2v2_queue", "o", "--"),
            ),
            "Queue Length",
            f"{self.net.name}_queue.png",
        )

        self._save_series_figure(
            (
                (self.s1v1_process_rate, "s1v1_process_rate", ".", "-"),
                (self.s1v2_process_rate, "s1v2_process_rate", "o", "-"),
                (self.s1v3_process_rate, "s1v3_process_rate", "*", "-"),
                (self.s2v1_process_rate, "s2v1_process_rate", ".", "--"),
                (self.s2v2_process_rate, "s2v2_process_rate", "o", "--"),
            ),
            "Process Rate",
            f"{self.net.name}_process_rate.png",
        )
        # endregion

        # region Delay statistics from the event log
        pkt_delay_sfc1 = []
        pkt_delay_sfc2 = []
        loggerItems:list[OLogItem] = self.logger.extract_log_items()

        for item in loggerItems:
            # Only receive events with a recorded delay are relevant.
            if item.event != "r" or item.pkt_delay == "*":
                continue
            if item.to_node == "Ue2App":
                pkt_delay_sfc1.append(float(item.pkt_delay))
            elif item.to_node == "Ue4App":
                pkt_delay_sfc2.append(float(item.pkt_delay))

        # Report each SFC independently so one empty log does not hide the other.
        if pkt_delay_sfc1:
            over_delay_num_sfc1 = sum(delay > self.s1_aim_delay for delay in pkt_delay_sfc1)
            print(f"s1_over_delay_rate: {over_delay_num_sfc1/len(pkt_delay_sfc1)}")
        if pkt_delay_sfc2:
            over_delay_num_sfc2 = sum(delay > self.s2_aim_delay for delay in pkt_delay_sfc2)
            print(f"s2_over_delay_rate: {over_delay_num_sfc2/len(pkt_delay_sfc2)}")
        # endregion

        # region Revenue-to-cost ratio per SFC
        s1_total_cost = sum(self.s1_cost_list)
        s2_total_cost = sum(self.s2_cost_list)
        if s1_total_cost != 0:
            print("s1_rev_cost_rate:",sum(self.s1_revenue_list)/s1_total_cost)
        if s2_total_cost != 0:
            print("s2_rev_cost_rate:",sum(self.s2_revenue_list)/s2_total_cost)
        # endregion